/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Contains a general function for shooting rays with multiple threads.

#ifndef __shoot__
#define __shoot__

#include "config.h"
#include <iostream> 
#include <fstream> 
#include "counter.h"
#include "boost/thread/thread.hpp"
#include "boost/shared_ptr.hpp"
#include "boost/lexical_cast.hpp"
#include "hpm.h"
#include <vector> 
#include "xfer.h"
#include "terminator.h"
#include "threadlocal.h"
#include "mpi_master_impl.h"

namespace mcrx {
  // Forward declarations of the thread functor and the shooter
  // functors that are defined further down in this file.
  template<typename, typename> class shooter_thread;
  class scatter_shooter;
  class integration_shooter;
  class nonscatter_shooter;
  class bidir_shooter;

  /// Returns the number of rays assigned to this task out of
  /// n_rays_desired. Defined below.
  template<typename T_xfer>
  long calculate_nrays(T_xfer& x, long n_rays_desired);

  /// Shoots rays using n_threads threads. Defined below, where the
  /// individual parameters are documented.
  template<typename T_xfer, typename T_shooter>
  bool
  shoot (T_xfer& xx,
	 T_shooter s, 
	 const terminator& t,
	 std::vector<T_rng::T_state>& rng_states,
	 long n_rays_desired,
	 long& n_rays,
	 int n_threads,
	 bool bind_threads,
	 int bank_size,
	 int hpm_no,
	 std::string hpm_string);
}


/** Distributes the total number of rays requested onto the different
    tasks in relation to their luminosity. The luminosity of this
    task, x.emission().luminosity(), is combined over the tasks with
    mpi_reduce, and this task receives n_rays_desired scaled by its
    fraction of that total, rounded to the nearest integer. If the
    tasks share the same sources, they will thus divide the number of
    rays evenly; if they have different sources, each gets rays in
    relation to its luminosity. The luminosity of the rays is
    normalized to the number of rays on each task at creation, so
    this preserves the total luminosity.

    If the total luminosity is not a positive number (zero or NaN),
    no fraction can be formed and the task is assigned zero rays.

    \param x xfer object whose emission supplies the luminosity.
    \param n_rays_desired total number of rays requested over all tasks.
    \return the number of rays this task should shoot. */
template<typename T_xfer>
long mcrx::calculate_nrays(T_xfer& x, long n_rays_desired)
{
  // The reduction is performed unconditionally, before any decision
  // is made, so that every task takes part in it.
  const T_float L = x.emission().luminosity();
  const T_float L_tot = mpi_reduce(L, std::plus<T_float>());

  // With a vanishing total the ratio L/L_tot is NaN, and converting
  // NaN to long is undefined. Only form the ratio when it is well
  // defined; otherwise this task gets no rays.
  long n_rays_this_task = 0;
  if(L_tot > 0)
    n_rays_this_task = static_cast<long>(round(n_rays_desired*L/L_tot));

  if(mpi_size()>1)
    printf("Task %d has luminosity %g/%g, gets %ld rays\n",mpi_rank(),L,L_tot,n_rays_this_task);
  return n_rays_this_task;
}


  
/** Shooter functor for the scattering case. It reports itself as
    non-trivial: produce() runs the supplied mpi_master and consume()
    enters the worker loop of the xfer object with
    T_xfer::create_work_shooting as the work creator. The direct call
    operator is not supported for this shooter and throws. */
class mcrx::scatter_shooter {
public:
  scatter_shooter() {}

  /// Always false for this shooter.
  bool trivial() const
  {
    return false;
  }

  /// Runs the master object.
  template<typename T_xfer>
  void produce(mpi_master<T_xfer>& master) const
  {
    master.run();
  }

  /// Runs the worker loop of x with the shared ray counter nr and
  /// the desired number of rays nrd.
  template<typename T_xfer>
  void consume(T_xfer& x, tbb::atomic<long>& nr, long nrd) const
  {
    x.worker_loop(nr, nrd, &T_xfer::create_work_shooting);
  }

  /// Not supported; throws the integer 1.
  template<typename T_xfer>
  void operator()(T_xfer&, long) const
  {
    throw 1;
  }

  /// Forwards to the free function mcrx::calculate_nrays().
  template<typename T_xfer>
  long calculate_nrays(T_xfer& x, long nrd) const
  {
    return mcrx::calculate_nrays(x, nrd);
  }

  /// Nothing to initialize.
  template<typename T_xfer>
  void init(T_xfer&) const {}
};

/** Shooter functor for integration rays. The mode given at
    construction selects whether the worker loop creates work with
    T_xfer::create_work_intensity_integration or with
    T_xfer::create_work_emission_integration. It reports itself as
    non-trivial, so produce() and consume() are used; the direct call
    operator throws. */
class mcrx::integration_shooter {
public:
  /// Selects which kind of integration work is created.
  enum mode { intensity, emission };

private:
  mode mode_;

public:
  integration_shooter(mode m)
    : mode_(m)
  {}

  /// Always false for this shooter.
  bool trivial() const
  {
    return false;
  }

  /// Runs the master object.
  template<typename T_xfer>
  void produce(mpi_master<T_xfer>& master) const
  {
    master.run();
  }

  /// Runs the worker loop of x, handing it the work-creation member
  /// function that corresponds to the mode.
  template<typename T_xfer>
  void consume(T_xfer& x, tbb::atomic<long>& nr, long nrd) const
  {
    typedef boost::function<bool (T_xfer*, tbb::atomic<long>&, long)>
      T_work_creator;

    if(mode_ == intensity)
      x.worker_loop(nr, nrd,
                    T_work_creator(&T_xfer::create_work_intensity_integration));
    else
      x.worker_loop(nr, nrd,
                    T_work_creator(&T_xfer::create_work_emission_integration));
  }

  /// Not supported; throws the integer 1.
  template<typename T_xfer>
  void operator()(T_xfer&, long) const
  {
    throw 1;
  }

  /// The number of rays is expected to be zero for integration
  /// (checked with an assert) and is returned unchanged.
  template<typename T_xfer>
  long calculate_nrays(T_xfer&, long nrd) const
  {
    assert(nrd==0);
    return nrd;
  }

  /// Sets up the ray integration in x. The queue has to be filled
  /// *before* the workers are started.
  template<typename T_xfer>
  void init(T_xfer& x) const
  {
    x.init_ray_integration();
  }
};


/** Shooter functor for bidirectional rays. It reports itself as
    trivial: the call operator calls T_xfer::bidirectional_shoot,
    while produce() and consume() are not supported and throw. */
class mcrx::bidir_shooter {
private:
  int ns_max;
  int nc_max;

public:
  /** Creates a bidirectional shooter. The maximum number of source
      and camera vertices can be set with nsm and ncm, respectively.
      The default for both is blitz::huge(int()), i.e. effectively
      unlimited, which gives the full bidirectional method. With a
      camera vertex maximum of 1, we get the normal forward method. */
  bidir_shooter(int nsm = blitz::huge(int()),
                int ncm = blitz::huge(int()))
    : ns_max(nsm), nc_max(ncm)
  {}

  /// Always true for this shooter.
  bool trivial() const
  {
    return true;
  }

  /** Shoots bidirectionally with the stored vertex limits. The
      second argument is ignored.
      \todo The normalization for bidirectional shooting does not take
      into account number of rays like the other shootings. */
  template<typename T_xfer>
  void operator()(T_xfer& x, long) const
  {
    x.bidirectional_shoot(ns_max, nc_max);
  }

  /// Not supported; throws the integer 1.
  template<typename T_xfer>
  void produce(mpi_master<T_xfer>&) const
  {
    throw 1;
  }

  /// Not supported; throws the integer 1.
  template<typename T_xfer>
  void consume(T_xfer&, tbb::atomic<long>&, long) const
  {
    throw 1;
  }

  /// Nothing to initialize.
  template<typename T_xfer>
  void init(T_xfer&) const {}

  /// Asserts that there is only a single task and then forwards to
  /// the free function mcrx::calculate_nrays().
  template<typename T_xfer>
  long calculate_nrays(T_xfer& x, long nrd) const
  {
    assert(mpi_size()==1);
    return mcrx::calculate_nrays(x, nrd);
  }
};

/** Shooter functor for the non-scattering case. It reports itself as
    trivial: the call operator calls T_xfer::shoot_isotropic, while
    produce() and consume() are not supported and throw. */
class mcrx::nonscatter_shooter {
public:
  nonscatter_shooter() {}

  /// Always true for this shooter.
  bool trivial() const
  {
    return true;
  }

  /// Calls shoot_isotropic on x, passing on nrd.
  template<typename T_xfer>
  void operator()(T_xfer& x, long nrd) const
  {
    x.shoot_isotropic(nrd);
  }

  /// Not supported; throws the integer 1.
  template<typename T_xfer>
  void produce(mpi_master<T_xfer>&) const
  {
    throw 1;
  }

  /// Not supported; throws the integer 1.
  template<typename T_xfer>
  void consume(T_xfer&, tbb::atomic<long>&, long) const
  {
    throw 1;
  }

  /// Forwards to the free function mcrx::calculate_nrays().
  template<typename T_xfer>
  long calculate_nrays(T_xfer& x, long nrd) const
  {
    return mcrx::calculate_nrays(x, nrd);
  }

  /// Nothing to initialize.
  template<typename T_xfer>
  void init(T_xfer&) const {}
};


/// Size, in chars, of the padding array placed at the start of every
/// shooter_thread object so that thread objects don't share a cache
/// line.
const int cache_line_size = 128;

/** Thread object, executes the shooting loop for the shoot() function.
    The operator() of a shooter_thread object executes in the thread,
    which is started by shoot(). The object contains its own xfer
    object to do the radiative transfer and all additional necessary
    information. The terminator, the random number generator state,
    the ray counter and the progress counter are held by reference
    and must outlive the object. */
template<typename T_xfer, typename T_shooter>  
class mcrx::shooter_thread {
public:
  char padding [cache_line_size]; //!< To ensure threads don't share cache line

  /** Flag used to ensure that we've made a local copy of the object
      before the thread starts executing. It is false in every newly
      constructed or copied object and is set by operator(). */
  bool execute_;
  /// If true, operator() binds the thread and its allocations.
  bool bind_threads_;

  /// Number of this thread, passed to the binding, rng and hpm setup.
  int thread_number_;
  /// Bank size passed on to bind_thread().
  int bank_size_;

  T_xfer x; //!< Xfer object used by this thread
  /// Shooter functor knows how to tell the xfer object how to shoot a ray.
  const T_shooter s;
  /// Terminator functor knows how to check whether shooting should be
  /// terminated.
  const terminator& t;
  /// Reference to the state of the random number generator.
  T_rng::T_state& rng_state;
  /// Desired number of rays to be shot.
  long n_rays_desired;
  /// Reference to the ray counter, common for all threads.
  tbb::atomic<long>& n_rays;
  /// Number supplied by the caller of shoot() for use with hpm.
  int hpm_no;
  /// Title of the instrumented section, used to build the hpm names.
  std::string hpm_string;
  /// Reference to the progress counter created in shoot().
  counter& c;

  /** Creates a thread object. The xfer object and the shooter are
      copied; the terminator, rng state, ray counter and progress
      counter are stored as references. The initializers are listed
      in member declaration order, which is the order in which the
      members are actually initialized. */
  shooter_thread(int n, T_xfer& xxx,
		 const T_shooter& ss, const terminator& tt,
		 T_rng::T_state& st, long nrd, tbb::atomic<long>& nr, 
		 bool bt, int bs, int hno,
		 const std::string& hstr, counter& cc):
    execute_(false), bind_threads_(bt),
    thread_number_(n), bank_size_(bs),
    x(xxx), s(ss), t(tt), rng_state(st),
    n_rays_desired(nrd), n_rays(nr),
    hpm_no(hno), hpm_string(hstr),
    c(cc) {}

  /** Copies a thread object. Everything is copied from other except
      execute_, which always starts out false in the new object. */
  shooter_thread(const shooter_thread& other) :
    execute_(false), bind_threads_(other.bind_threads_),
    thread_number_(other.thread_number_), bank_size_(other.bank_size_),
    x(other.x), s(other.s), t(other.t), rng_state(other.rng_state),
    n_rays_desired(other.n_rays_desired), n_rays(other.n_rays),
    hpm_no(other.hpm_no), hpm_string(other.hpm_string),
    c(other.c) {}

  /// Thread entry point, defined below.
  void operator () ();
};


/** End-user function to shoot rays using multiple threads.  This
    very general function spawns threads and shoots, and is adaptable
    to all the cases we are interested in through the use of the
    shooter and terminator functors. The ray tracing set up is passed
    through the xfer object, which knows about the grid, emission and
    emergence objects.

    \param xx xfer object, contains the ray-tracing parameters. Each
      thread object gets its own copy.
    \param s Shooter functor, determines which xfer shooting function
      is called.
    \param t Terminator functor, determines how the shooting can be
      interrupted.
    \param rng_states The state vector which is used to initialize the
      random number generators in the threads. Must contain n_threads
      entries. The threads write their final generator state back.
    \param n_rays_desired The desired number of total rays to be shot,
      across all threads and tasks.
    \param n_rays Total ray counter which is set to the final number
      of rays shot (n_rays_desired unless shooting is interrupted by
      the terminator). This counter can be started from nonzero if a
      run is being restarted.
    \param n_threads Desired number of threads.
    \param bind_threads If true, the threads are bound to CPUs.
    \param bank_size The size of the banks that determine
      hyperthreading vs real cores.
    \param hpm_no If hpm is used, the number used in the call to
      hpmTstart.
    \param hpm_string If hpm is used, the "title" of the instrumented
      section.
    \return the value of the terminator t() after the threads have
      been joined. */
template<typename T_xfer, typename T_shooter>  
bool
mcrx::shoot (T_xfer& xx,
	     T_shooter s, 
	     const terminator& t,
	     std::vector<T_rng::T_state>& rng_states,
	     long n_rays_desired,
	     long& n_rays,
	     int n_threads,
	     bool bind_threads,
	     int bank_size,
	     int hpm_no,
	     std::string hpm_string)
{
  std::cout << "Spawning " << n_threads << " shooting threads" << std::endl;

  // There must be exactly one generator state per thread. The count
  // is converted explicitly so the comparison is not mixed-sign.
  assert (n_threads >= 0 &&
	  rng_states.size() ==
	  static_cast<std::vector<T_rng::T_state>::size_type>(n_threads));

  // shooter may need to do initialization 
  s.init(xx);

  // During the calculation, we keep the counter for number of rays in
  // an atomic.
  tbb::atomic<long> n_rays_cur;
  /// \todo need to divide the previously done number of rays too 
  n_rays_cur = n_rays;
  // divide the number of rays onto the different tasks
  const long n_rays_this_task = s.calculate_nrays(xx, n_rays_desired);
  mpi_master<xfer<typename T_xfer::T_dust_model, typename T_xfer::T_grid> >
	     master(xx, n_threads, n_rays_cur, n_rays_this_task); 

  // clear the hpm output file by just opening it. The stream is
  // closed again when it goes out of scope.
  const std::string hpm_file =
    hpm_string+"_"+boost::lexical_cast<std::string>(xx.task())+".hpmout";
  {
    std::ofstream truncate_hpm_file(hpm_file.c_str());
  }

  // Progress counter handed to the thread objects, starting from the
  // number of rays already shot.
  counter c (1000);
  c = n_rays;

  // Create all thread objects first; the threads are only started
  // once every object exists.
  boost::thread_group threads;
  typedef shooter_thread<T_xfer, T_shooter> T_thread;   
  std::vector<boost::shared_ptr<T_thread> > thread_objects;
  for (int i = 0; i < n_threads; ++i) {
    boost::shared_ptr<T_thread> thread_object
      (new T_thread (i, xx, s, t,
		     rng_states [i],
		     n_rays_this_task,
		     n_rays_cur, bind_threads, bank_size,
		     hpm_no, hpm_string,
		     c));
    thread_objects.push_back(thread_object);
  }

  // start threads. They are given a reference to the thread object,
  // not a copy.
  for (int i = 0; i < n_threads; ++i) {
    threads.create_thread(boost::ref(*thread_objects[i]));
  }

  // Non-trivial shooters run the master in this thread while the
  // spawned threads consume.
  if(!s.trivial())
    s.produce(master);

  threads.join_all();

  n_rays = n_rays_cur;

  std::cout << "Shooting complete" << std::endl;

  // Here we strictly can't know whether they were terminated or if
  // terminator was set during the joining... I guess it doesn't
  // matter much, except that it will dump instead of save, but then
  // again that's not necessarily so bad since we want to be quick
  // about it...
  return t();
}


/** Shoots rays for the thread object. Started by the Boost threads
    module when the thread is created. On the first entry (execute_
    false) it optionally binds the thread and its allocations, makes
    a local copy of the thread object, and runs that copy. The copy
    (execute_ true) sets up the random number generator and the hpm
    stage names, does the shooting through the shooter functor, stops
    the hpm instrumentation and writes the generator state back
    through rng_state. */
template<typename T_xfer, typename T_shooter>
void mcrx::shooter_thread<T_xfer, T_shooter>::operator () ()
{
  // trouble here. when the local copy goes out of scope, its copy of
  // the xfer will copy its copy of the emergence to the main
  // emergence object. But then we are left with the xfer object in
  // the first thread object, which has not received any rays and
  // which will attempt to copy its data to the emergence object too,
  // when the thread object is deleted. Also, the copy from local_copy
  // to the main emergence is subject to a race condition, since it
  // happens here, before the threads have exited. For that reason, we
  // don't make a local copy if the xfer object uses a local
  // emergence.
  // NOTE(review): the code below makes the local copy
  // unconditionally; confirm whether the caveat above is handled
  // elsewhere or is outdated.

  // See if this is the local copy or not.
  if(!execute_) {
    // first thing is to bind threads and allocations so the xfer
    // copy gets allocated in the correct place
    if(bind_threads_) {
      bind_thread(thread_number_, bank_size_);
      bind_allocations();
    }
    
    shooter_thread<T_xfer, T_shooter> local_copy(*this);
    local_copy.execute_=true;
    local_copy();
    return;
  }

  // Ensure that each thread has a distinct generator by using the
  // thread_number.
  x.rng().make_unique(thread_number_);
  x.rng().get_generator().setState(rng_state);
  x.set_thread_number(thread_number_);
  DEBUG(3,std::cout << "Shooting thread " << thread_number_ << " seed " << x.rng().get_generator().getStateString() << std::endl;);

  // Set the hpm stage names (annoying this is separate from the enum in xfer).
  const int n_stages=13;
  hpm::thread_init(thread_number_, n_stages);
  const char* const stages[n_stages] = 
    {"Worker", "Mainloop", "Forced", "Propagate", "Intensities", "Unforced", 
     "Ray_send", "Send_wait", "Emit", "Locate", "Scatter", "Emerge", 
     "Camera_ray"};
  for(int i=0; i<n_stages; ++i)
    hpm::set_stage_name(i, stages[i]);

  // No try/catch surrounds the shooting (an earlier handler that
  // printed a message and rethrew is disabled), so an exception
  // thrown here leaves this function.
  if(s.trivial()) {
    //the old way
    while(increment_if_less_than(n_rays, n_rays_desired) && !t())
      s (x, n_rays_desired); // Calls the appropriate xfer::shoot function
  }
  else
    // the shooter object starts the worker loop
    s.consume (x, n_rays, n_rays_desired);

  // The section name and the output file share the same prefix, so
  // it is built once.
  const std::string hpm_label =
    hpm_string+"_"+boost::lexical_cast<std::string>(x.task());
  hpm::thread_stop(hpm_label, hpm_label+".hpmout");

  // return random number generator state
  rng_state = x.rng().get_generator().getState();

  DEBUG(1,printf("Worker %d-%d exiting\n",x.task(),thread_number_););
}




#endif
